{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_yaml(yaml_path: str) -> dict:\n",
    "    \"\"\"通过id找到名称\n",
    "\n",
    "    Args:\n",
    "        yaml_path (str): yaml文件路径\n",
    "\n",
    "    Returns:\n",
    "        yaml (dict)\n",
    "    \"\"\"\n",
    "    with open(yaml_path, 'r', encoding='utf-8') as f:\n",
    "        y = yaml.load(f, Loader=yaml.FullLoader)\n",
    "\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: 0,\n",
       " 1: 23,\n",
       " 2: 23,\n",
       " 3: 23,\n",
       " 4: 20,\n",
       " 5: 20,\n",
       " 6: 21,\n",
       " 7: 21,\n",
       " 8: 14,\n",
       " 9: 14,\n",
       " 10: 18,\n",
       " 11: 11,\n",
       " 12: 0,\n",
       " 13: 0,\n",
       " 14: 0,\n",
       " 15: 0,\n",
       " 16: 0,\n",
       " 17: 23,\n",
       " 18: 0,\n",
       " 19: 23}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "remap_dict: dict = load_yaml(r\"../weights/yolov5m-sgd-e300-bg1000/remap.yaml\")\n",
    "remap_dict[\"remap\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {\n",
    "    'detect': [\n",
    "        {'class_index': 10, 'class': 'mian_ya', 'confidence': 0.935646653175354, 'box': [0, 1011, 546, 1562]},\n",
    "        {'class_index': 10, 'class': 'mian_ya', 'confidence': 0.8366711735725403, 'box': [154, 539, 442, 1117]},\n",
    "        {'class_index': 10, 'class': 'mian_ya', 'confidence': 0.7807613611221313, 'box': [290, 1481, 1131, 2015]},\n",
    "        {'class_index': 10, 'class': 'mian_ya', 'confidence': 0.4629181921482086, 'box': [427, 374, 983, 1322]},\n",
    "        {'class_index': 14, 'class': 'piao_chong_yong', 'confidence': 0.6198558211326599, 'box': [821, 346, 838, 380]},\n",
    "        {'class_index': 15, 'class': 'piao_chong_you_chong', 'confidence': 0.583463191986084, 'box': [603, 652, 618, 675]},\n",
    "        {'class_index': 14, 'class': 'piao_chong_yong', 'confidence': 0.663463191986084, 'box': [500, 449, 620, 675]}\n",
    "        ],\n",
    "    'count': {10: 4, 14: 1, 15: 1},\n",
    "    'image_size': (2016, 1134, 3)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remap(data: dict, remap_dict: dict) -> dict:\n",
    "    \"\"\"将预测id转换为数据库id,删除不需要的数据\n",
    "\n",
    "    Args:\n",
    "        data (dict):       预测数据\n",
    "        remap_dict (dict): origin id to database id dict\n",
    "\n",
    "    Returns:\n",
    "        dict:              remap的数据\n",
    "    \"\"\"\n",
    "    new_count = []\n",
    "    new_box   = []\n",
    "    for box in data[\"detect\"]:\n",
    "        # 去除名字\n",
    "        box.pop(\"class\")\n",
    "        # 映射id\n",
    "        box[\"class_index\"] = remap_dict[box[\"class_index\"]]\n",
    "\n",
    "        # 忽略new_id为0的类别,这个类别不要\n",
    "        if box[\"class_index\"] == 0:\n",
    "            # 不能直接remove,会导致训练数据长度变化,后面一个的数据读取不到\n",
    "            continue\n",
    "\n",
    "        new_box.append(box)\n",
    "        new_count.append(box[\"class_index\"])\n",
    "\n",
    "    # reamp id 计数\n",
    "    data[\"detect\"] = new_box\n",
    "    data[\"count\"] = dict(Counter(new_count))\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'detect': [{'class_index': 18,\n",
       "   'confidence': 0.935646653175354,\n",
       "   'box': [0, 1011, 546, 1562]},\n",
       "  {'class_index': 18,\n",
       "   'confidence': 0.8366711735725403,\n",
       "   'box': [154, 539, 442, 1117]},\n",
       "  {'class_index': 18,\n",
       "   'confidence': 0.7807613611221313,\n",
       "   'box': [290, 1481, 1131, 2015]},\n",
       "  {'class_index': 18,\n",
       "   'confidence': 0.4629181921482086,\n",
       "   'box': [427, 374, 983, 1322]}],\n",
       " 'count': {18: 4},\n",
       " 'image_size': (2016, 1134, 3)}"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "remap(data, remap_dict[\"remap\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试python动态删除数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 A\n",
      "5\n",
      "1 B\n",
      "5\n",
      "2 C\n",
      "3 E\n",
      "4\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'name': 'A'}, {'name': 'B'}, {'name': 'E'}]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data = [\n",
    "    {\"name\": \"A\"},\n",
    "    {\"name\": \"B\"},\n",
    "    {\"name\": \"C\"},\n",
    "    {\"name\": \"D\"},\n",
    "    {\"name\": \"E\"},\n",
    "]\n",
    "\n",
    "new_test_data = []\n",
    "for i, d in enumerate(test_data):\n",
    "    print(i, d[\"name\"])\n",
    "    if d[\"name\"] == \"C\":\n",
    "        test_data.remove(d)\n",
    "        continue\n",
    "    new_test_data.append(d)\n",
    "    print(len(test_data))\n",
    "# name == c 时\n",
    "# 会pop掉自己,长度变短,但是索引累加,会导致跳过数据\n",
    "# ABCDE index=2 => C\n",
    "# index += 1\n",
    "# ABDE  index=3 => E 这一次就是E了,直接跳过了D,动态删除C会导致下一个D被忽略\n",
    "new_test_data # 可以看到虽然只删除了C,D也被跳过了"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解决办法1: 倒序删除,删除一个数据,其他数据向前移动,不影响前面的数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 E\n",
      "5\n",
      "1 D\n",
      "5\n",
      "2 C\n",
      "3 B\n",
      "4\n",
      "4 A\n",
      "4\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'name': 'A'}, {'name': 'B'}, {'name': 'D'}, {'name': 'E'}]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "test_data = [\n",
    "    {\"name\": \"A\"},\n",
    "    {\"name\": \"B\"},\n",
    "    {\"name\": \"C\"},\n",
    "    {\"name\": \"D\"},\n",
    "    {\"name\": \"E\"},\n",
    "]\n",
    "\n",
    "new_test_data = []\n",
    "for i, d in enumerate(test_data[::-1]):\n",
    "    print(i, d[\"name\"])\n",
    "    if d[\"name\"] == \"C\":\n",
    "        test_data.remove(d)\n",
    "        continue\n",
    "    new_test_data.append(d)\n",
    "    print(len(test_data))\n",
    "new_test_data[::-1]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解决办法2: 使用新数组存放新数据,不删除旧数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 A\n",
      "1 B\n",
      "2 C\n",
      "3 D\n",
      "4 E\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'name': 'A'}, {'name': 'B'}, {'name': 'D'}, {'name': 'E'}]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data = [\n",
    "    {\"name\": \"A\"},\n",
    "    {\"name\": \"B\"},\n",
    "    {\"name\": \"C\"},\n",
    "    {\"name\": \"D\"},\n",
    "    {\"name\": \"E\"},\n",
    "]\n",
    "new_test_data = []\n",
    "for i, d in enumerate(test_data):\n",
    "    print(i, d[\"name\"])\n",
    "    if d[\"name\"] == \"C\":\n",
    "        continue\n",
    "    new_test_data.append(d)\n",
    "new_test_data"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
